#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import pandas as pd
from numpy import split
from numpy import array
from matplotlib import pyplot
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM


# Sliding-window lengths read by preprocess(): every sample takes N_IN
# consecutive time steps as input and the N_OUT step(s) right after them
# as the target.
N_IN: int = 7
N_OUT: int = 1

# Build and train the LSTM forecaster.
# input:
#       train_set: training series; turned into sliding-window samples by preprocess()
# output:
#         model: the fitted Sequential model
#         history: the object returned by model.fit (its .history dict holds
#                  the per-epoch 'loss' and 'val_loss')
def build_model(train_set):
	# prepare data as sequences: X is [samples, N_IN, 1], y is [samples, N_OUT, 1]
	train_x, train_y = preprocess(train_set)

	# Dense(N_OUT) emits [samples, N_OUT]; drop the trailing feature axis so the
	# targets have exactly that shape instead of relying on Keras to squeeze it
	train_y = train_y[:, :, 0]

	# define parameters
	epochs, batch_size = 300, 7
	n_val = 14  # the last n_val windows are held out for validation

	# define model; the window geometry comes from N_IN / N_OUT so it cannot
	# drift away from what preprocess() produces
	model = Sequential()
	model.add(LSTM(30, activation='relu', input_shape=(N_IN, 1), return_sequences=False))
	model.add(Dense(N_OUT))
	model.compile(loss='mse', optimizer='adam')

	# fit model on everything except the held-out tail
	val_data = (train_x[-n_val:], train_y[-n_val:])
	history = model.fit(train_x[:-n_val], train_y[:-n_val],
		validation_data=val_data, epochs=epochs,
		batch_size=batch_size, verbose=2)
	return model, history

# Create training sequences with a sliding window that advances one step at a time.
# input:
#       train_set: 3-D array; it is flattened to [dim0 * dim1, dim2] and only
#                  column 0 of the result is used
#       n_in: input window length (optional, defaults to N_IN)
#       n_out: output window length (optional, defaults to N_OUT)
# output:
#         X:[samples, input time step, feature]
#         y:[samples, output time step, feature]
#         Both have zero samples when the series is shorter than n_in + n_out.
# raises:
#         ValueError if n_in or n_out is smaller than 1
def preprocess(train_set, n_in=None, n_out=None):
	# resolve the window lengths at call time so the module constants stay the default
	n_in = N_IN if n_in is None else n_in
	n_out = N_OUT if n_out is None else n_out
	if n_in < 1 or n_out < 1:
		raise ValueError("n_in and n_out must both be at least 1")

	# one row per time step, then keep the first feature only
	data = train_set.reshape((train_set.shape[0]*train_set.shape[1], train_set.shape[2]))
	series = data[:, 0]

	# a window starting at i needs n_in inputs followed by n_out targets; a
	# series that is too short yields zero windows rather than an IndexError
	n_windows = max(len(series) - n_in - n_out + 1, 0)
	X = array([series[i:i + n_in] for i in range(n_windows)])
	y = array([series[i + n_in:i + n_in + n_out] for i in range(n_windows)])

	# reshape input & output into [samples, timesteps, features]
	X = X.reshape(n_windows, n_in, 1)
	y = y.reshape(n_windows, n_out, 1)
	return X, y

# Forecast from a single input window.
# input:
#       model: the pre-trained model (anything exposing predict(data, verbose=...))
#       test: one input window with at least two dimensions; it is flattened
#             to dim0 * dim1 time steps of a single feature
# output:
#       yhat: row 0 of the model's output, i.e. the forecast for this window
def predict(model, test):
	window = array(test)

	# a batch holding exactly one sample: [1, time steps, 1]
	n_steps = window.shape[0] * window.shape[1]
	batch = window.reshape((1, n_steps, 1))

	# only one sample was submitted, so only row 0 of the result is returned
	return model.predict(batch, verbose=0)[0]


# Train the model, then forecast once for every N_IN-long window of test_set.
# input:
#       train_set: data used to fit the model (handed to build_model)
#       test_set: series the forecasting window slides over
# output:
#         predictions: array holding one forecast per window, in window order
#         his: the training history returned by build_model
def run(train_set, test_set):
	# build the model
	model, his = build_model(train_set)

	# walk forward one step at a time: the window starting at i covers
	# test_set[i:i+N_IN] and is used to forecast what follows it, so the
	# last window ends one step before the end of test_set
	n_windows = len(test_set) - N_IN
	predictions = array([predict(model, test_set[i:i + N_IN]) for i in range(n_windows)])

	return predictions, his

# Draw one forecast figure: predictions in red, the modelled series in black
# and, when given, a reference series in blue.
#       pred_days / pred: x and y values of the forecasts
#       series_days / series: x and y values of the modelled series
#       caption_ys: y positions of the 'Training Set', 'Validation Set' and
#                   'Test Set' captions, in that order
#       reference: optional extra series plotted over series_days as 'actual'
def _plot_forecast(pred_days, pred, series_days, series, caption_ys, reference=None):
	pyplot.plot(pred_days, pred, label='prediction', color='red')
	pyplot.plot(series_days, series, label='recovered', color='black')
	legend_names = ['predicted', 'recovered']
	if reference is not None:
		pyplot.plot(series_days, reference, label='actual', color='blue')
		legend_names.append('actual')

	# dashed separators between the captioned regions
	pyplot.axvline(x=40, color='k', linestyle='--')
	pyplot.axvline(x=56, color='k', linestyle='--')
	captions = ('Training Set', 'Validation Set', 'Test Set')
	for caption_x, caption_y, caption in zip((30, 40.5, 56.5), caption_ys, captions):
		pyplot.text(x=caption_x, y=caption_y, s=caption, color='k', fontsize=15)

	pyplot.legend(legend_names, loc='upper left')
	pyplot.xlabel('day')
	pyplot.show()


# Import data and build the training & test sets
data = pd.read_excel("/Users/urielyang/OneDrive - Emory University/Honors/Artificial_Q.xlsx")

# Column 0 is the modelled series; each set is shaped [samples, 1, 1].
# NOTE: test_set covers rows 0-79, so it also contains the 56 training rows.
train_set = data.iloc[0:56, 0].values.reshape(56, 1, 1)
test_set = data.iloc[0:80, 0].values.reshape(80, 1, 1)

# Run
prediction, history = run(train_set, test_set)

# Flatten both arrays to 1-D for plotting
actual = test_set.reshape(test_set.shape[0]*test_set.shape[1])
temp = prediction.reshape(prediction.shape[0]*prediction.shape[1])

# Days are numbered from 1. There are fewer forecasts than observations
# (the first window has nothing before it to forecast from), so the forecasts
# are aligned with the end of the series.
actual_days = range(1, len(actual) + 1)
pred_days = range(len(actual) - len(temp) + 1, len(actual) + 1)

# Plot predictions against the modelled series only
_plot_forecast(pred_days, temp, actual_days, actual, (0.95, 0.95, 0.95))

# Same figure with column 1 of the sheet added as the 'actual' curve
_plot_forecast(pred_days, temp, actual_days, actual, (0.35, 0.35, 0.45),
	reference=data.iloc[0:80, 1])

# Plot training & validation error
pyplot.plot(history.history['loss'])
pyplot.plot(history.history['val_loss'])
pyplot.ylabel('loss (MSE)')
pyplot.xlabel('epoch')
pyplot.legend(['train', 'validation'], loc='upper left')
pyplot.show()